{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1972a352-a0e3-41b7-81dc-dd4ae2b890c3",
   "metadata": {},
   "source": [
    "## In-Context Learning, Chain-of-Thought, Reasoning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05d999bc-83a3-454f-a8a4-44cbff1fcedc",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/microsoft/LLMLingua/blob/main/examples/CoT.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe3ed1ce-d38d-4048-9db6-9707b55dc642",
   "metadata": {},
   "source": [
    "**In-Context Learning (ICL)** is a distinctive capability of large language models (LLMs) that lets them quickly learn a task from a few examples. ICL is often combined with Chain-of-Thought (CoT) prompting, in which the examples spell out the reasoning process in detail to strengthen the LLM's reasoning. For instance, Fu et al.'s Complexity-Based Prompting improved GSM8K performance from 74.9 to 78.85 with GPT-3.5-Turbo-0301. However, this approach also leads to increasingly long prompts: the GSM8K prompt used here contains **2,366** tokens."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae003ead-2f07-44a4-b641-2e33be920dd9",
   "metadata": {},
   "source": [
    "<center><img width=\"800\" src=\"../images/LLMLingua_framework.png\"></center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b39b33f-5860-4825-8f00-d60aed0dce86",
   "metadata": {},
   "source": [
    "To address this, we propose [**LLMLingua**](https://arxiv.org/abs/2310.05736), which uses a well-trained small language model after alignment, such as GPT2-small or LLaMA-7B, to detect unimportant tokens in the prompt and remove them. The compressed prompt can then be used for inference with black-box LLMs, achieving up to **20x** compression with minimal performance loss."
   ]
  },
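  {
   "cell_type": "markdown",
   "id": "7c1e5a2b-3f4d-4e6a-9b8c-0d1e2f3a4b5c",
   "metadata": {},
   "source": [
    "To make the idea of dropping low-information tokens concrete, here is a toy sketch. It is not LLMLingua's algorithm: LLMLingua scores tokens with a small language model, whereas this stand-in assumes a unigram frequency model fitted on the text itself. The helper names `self_information` and `prune_tokens` are hypothetical and exist only for this illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def self_information(tokens):\n",
    "    # -log2 p(token) under a unigram model estimated from `tokens` (toy assumption).\n",
    "    counts = Counter(tokens)\n",
    "    total = len(tokens)\n",
    "    return [-math.log2(counts[t] / total) for t in tokens]\n",
    "\n",
    "\n",
    "def prune_tokens(tokens, keep):\n",
    "    # Keep the `keep` highest-scoring tokens, preserving their original order.\n",
    "    scores = self_information(tokens)\n",
    "    # Rank positions by descending score; ties go to the earlier position.\n",
    "    ranked = sorted(range(len(tokens)), key=lambda i: (-scores[i], i))\n",
    "    kept = sorted(ranked[:keep])\n",
    "    return [tokens[i] for i in kept]\n",
    "\n",
    "\n",
    "tokens = \"the cat sat on the mat and the dog sat on the rug\".split()\n",
    "print(\" \".join(prune_tokens(tokens, keep=5)))  # cat mat and dog rug\n",
    "```\n",
    "\n",
    "Frequent tokens such as \"the\" carry little information under this model and are dropped first. A real compressor would condition each token's score on its context instead of using raw frequencies."
   ]
  },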
  {
   "cell_type": "markdown",
   "id": "18422597-687a-43aa-a6ed-ce6244d0eb55",
   "metadata": {},
   "source": [
    "### GSM8K"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51a7accd-5ec2-4ed2-9582-1afdb441a998",
   "metadata": {},
   "source": [
    "Next, we demonstrate LLMLingua on the GSM8K dataset. The original prompt, an 8-shot setup with 2,366 tokens, can be found at https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/lib_prompt/prompt_hardest.txt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a970a901-11bd-43af-a8bc-7fb2fc6a1a07",
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "# Install dependency.\n",
    "!pip install llmlingua datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "641235b6-71a5-4f2a-8eec-272c73931bef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-10-30 09:15:31--  https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 8464 (8.3K) [text/plain]\n",
      "Saving to: ‘prompt_hardest.txt’\n",
      "\n",
      "prompt_hardest.txt  100%[===================>]   8.27K  --.-KB/s    in 0s      \n",
      "\n",
      "2023-10-30 09:15:31 (78.8 MB/s) - ‘prompt_hardest.txt’ saved [8464/8464]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Download the original prompt and dataset\n",
    "from datasets import load_dataset\n",
    "\n",
    "!wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
    "with open(\"./prompt_hardest.txt\") as f:\n",
    "    prompt_complex = f.read()\n",
    "gsm8k = load_dataset(\"gsm8k\", \"main\")\n",
    "gsm8k_test = gsm8k[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cbbbf3de-a9d6-46cf-afab-dcb72a6154ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the OpenAI API\n",
    "import openai\n",
    "\n",
    "openai.api_key = \"<insert_openai_key>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "46506810-8565-43da-984b-d862c56b49c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Or use the Azure OpenAI API\n",
    "import openai\n",
    "\n",
    "openai.api_key = \"<insert_openai_key>\"\n",
    "openai.api_base = \"https://xxxx.openai.azure.com/\"\n",
    "openai.api_type = \"azure\"\n",
    "openai.api_version = \"2023-05-15\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8676ffa-5117-44dc-9742-bb9ab1d56e0c",
   "metadata": {},
   "source": [
    "### Setup Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cc17bbc5-86cb-4d15-a730-955af85a10b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select one example from the GSM8K test set\n",
    "question, answer = [gsm8k_test[2][key] for key in [\"question\", \"answer\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "58718a19-cc4e-4002-a92a-58ea3de9c9d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Josh decides to try flipping a house.  He buys a house for $80,000 and then puts in $50,000 in repairs.  This increased the value of the house by 150%.  How much profit did he make?\n",
      "Answer: The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\n",
      "He increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\n",
      "So the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\n",
      "So he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n",
      "#### 70000\n"
     ]
    }
   ],
   "source": [
    "# Ground-truth Answer\n",
    "print(\"Question:\", question)\n",
    "print(\"Answer:\", answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba1c6d52-dc87-434c-a41c-0bbc8a286504",
   "metadata": {},
   "source": [
    "#### Response to the Original Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "3d441f10-c5c7-4d45-b09a-717e536b36bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"id\": \"cmpl-8FZvcX70FH7ck9c9MegWmnUocH0A0\",\n",
      "    \"object\": \"text_completion\",\n",
      "    \"created\": 1698723720,\n",
      "    \"model\": \"gpt-35-turbo\",\n",
      "    \"choices\": [\n",
      "        {\n",
      "            \"text\": \" \\nLet's think step by step\\nThe repairs increased the value of the house by 150% so that means it increased by 80,000*1.5=$<<80000*1.5=120000>>120,000\\nSo the total value of the house is 80,000+120,000=$<<80000+120000=200000>>200,000\\nHe spent 80,000+50,000=$<<80000+50000=130000>>130,000\\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\\nThe answer is 70,000\",\n",
      "            \"index\": 0,\n",
      "            \"finish_reason\": \"stop\",\n",
      "            \"logprobs\": null\n",
      "        }\n",
      "    ],\n",
      "    \"usage\": {\n",
      "        \"prompt_tokens\": 2428,\n",
      "        \"completion_tokens\": 142,\n",
      "        \"total_tokens\": 2570\n",
      "    }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Response to the original prompt\n",
    "import json\n",
    "\n",
    "instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
    "prompt = instruction + prompt_complex + \"\\n\\nQuestion: \" + question\n",
    "\n",
    "request_data = {\n",
    "    \"prompt\": prompt,\n",
    "    \"max_tokens\": 400,\n",
    "    \"temperature\": 0,\n",
    "    \"top_p\": 1,\n",
    "    \"n\": 1,\n",
    "    \"stream\": False,\n",
    "    \"stop\": \"\\n\\n\",\n",
    "}\n",
    "response = openai.Completion.create(\n",
    "    \"gpt-3.5-turbo-0301\",\n",
    "    **request_data,\n",
    ")\n",
    "print(json.dumps(response, indent=4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aa90492-8ad1-4a89-85c5-26b8472f1ff0",
   "metadata": {},
   "source": [
    "#### Response to the Compressed Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fa638dec-c9ec-4dce-9dac-d768145de714",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ec90053e7274da59973427652f879a1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hjiang/.local/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home/hjiang/.local/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Setup LLMLingua\n",
    "from llmlingua import PromptCompressor\n",
    "\n",
    "llm_lingua = PromptCompressor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "5f61a186-6641-4118-ad04-5245a53b6d79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"compressed_prompt\": \"Question: Sam bought a dozen boxes, each with 30 highlighter pens inside, for $10 each. He reanged five of boxes into packages of sixlters each and sold them $3 per. He sold the rest theters separately at the of three pens $2. How much did make in total, dollars?\\nLets think step step\\nSam bought 1 boxes x00 oflters.\\nHe bought 12 00ters in total\\nSam then took5 boxes 6ters0ters\\nHe sold these boxes for 5 *5\\nAfterelling these  boxes there were 30330ters remaining\\nese form 330 /30 of three\\n sold each for2 each, so made * =0 from\\n total, he0 $15\\nSince his original1 he earned $120 = $115 in profit.\\nThe answer is 115\",\n",
      "    \"origin_tokens\": 2365,\n",
      "    \"compressed_tokens\": 174,\n",
      "    \"ratio\": \"13.6x\",\n",
      "    \"saving\": \", Saving $0.1 in GPT-4.\"\n",
      "}\n",
      "Response: {\n",
      "  \"id\": \"cmpl-8FZwYp1QIwiQs6pEhy2cRK6wnLnAO\",\n",
      "  \"object\": \"text_completion\",\n",
      "  \"created\": 1698723778,\n",
      "  \"model\": \"gpt-35-turbo\",\n",
      "  \"choices\": [\n",
      "    {\n",
      "      \"text\": \" \\n\\nThe repairs increased the value of the house by 150% so that means it increased by 80000*1.5=$<<80000*1.5=120000>>120,000\\nSo the total value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\\nThat means he made a profit of 200,000-80,000-50,000=$<<200000-80000-50000=70000>>70,000. Answer: \\\\boxed{70,000}.<|im_end|>\",\n",
      "      \"index\": 0,\n",
      "      \"finish_reason\": \"stop\",\n",
      "      \"logprobs\": null\n",
      "    }\n",
      "  ],\n",
      "  \"usage\": {\n",
      "    \"prompt_tokens\": 237,\n",
      "    \"completion_tokens\": 120,\n",
      "    \"total_tokens\": 357\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Compress the prompt to 174 tokens (13.6x compression)\n",
    "compressed_prompt = llm_lingua.compress_prompt(\n",
    "    prompt_complex.split(\"\\n\\n\"),\n",
    "    instruction=\"\",\n",
    "    question=\"\",\n",
    "    target_token=200,\n",
    "    context_budget=\"*1.5\",\n",
    "    iterative_size=100,\n",
    ")\n",
    "\n",
    "instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
    "prompt = (\n",
    "    instruction + compressed_prompt[\"compressed_prompt\"] + \"\\n\\nQuestion: \" + question\n",
    ")\n",
    "\n",
    "request_data = {\n",
    "    \"prompt\": prompt,\n",
    "    \"max_tokens\": 400,\n",
    "    \"temperature\": 0,\n",
    "    \"top_p\": 1,\n",
    "    \"n\": 1,\n",
    "    \"stream\": False,\n",
    "    \"stop\": \"\\r\\n\",\n",
    "}\n",
    "response = openai.Completion.create(\n",
    "    \"gpt-3.5-turbo-0301\",\n",
    "    **request_data,\n",
    ")\n",
    "print(json.dumps(compressed_prompt, indent=4))\n",
    "print(\"Response:\", response)"
   ]
  },
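  {
   "cell_type": "markdown",
   "id": "2d9f4b71-8a3c-4c5e-b6d7-1e0f9a8b7c6d",
   "metadata": {},
   "source": [
    "The reported `ratio` can be checked by hand from the token counts printed above. The short sketch below redoes that arithmetic and also compares the `prompt_tokens` values from the two API responses. The variable names are illustrative, and the numbers are copied from the outputs in this notebook.\n",
    "\n",
    "```python\n",
    "origin_tokens = 2365  # \"origin_tokens\" in the compress_prompt output\n",
    "compressed_tokens = 174  # \"compressed_tokens\" in the compress_prompt output\n",
    "ratio = origin_tokens / compressed_tokens\n",
    "print(f\"{ratio:.1f}x\")  # 13.6x\n",
    "\n",
    "# \"prompt_tokens\" from the usage field of each API response\n",
    "full_prompt_tokens = 2428\n",
    "short_prompt_tokens = 237\n",
    "end_to_end = full_prompt_tokens / short_prompt_tokens\n",
    "print(f\"{end_to_end:.1f}x\")  # 10.2x\n",
    "```\n",
    "\n",
    "The end-to-end reduction is smaller than 13.6x because the instruction and the test question are sent uncompressed in both requests."
   ]
  },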
  {
   "cell_type": "markdown",
   "id": "1f89bb0f-7959-4a14-95be-dc80d88ce576",
   "metadata": {},
   "source": [
    "### Evaluate on the GSM8K Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c1ac9bf5-23a9-446c-9394-8bb19aa1d89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
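    "def extract_ans(ans_model):\n",
    "    # Split the model output at the first line containing \"answer is\":\n",
    "    # lines up to and including it form the answer, the rest is the residual.\n",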
    "    ans_model = ans_model.split(\"\\n\")\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if \"answer is\" in al:\n",
    "            break\n",
    "    residual = list(ans_model[li + 1 :])\n",
    "    ans = \"\\n\".join(ans)\n",
    "    residual = \"\\n\".join(residual)\n",
    "    return ans, residual\n",
    "\n",
    "\n",
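    "def parse_pred_ans(filename):\n",
    "    # Parse a results file made of \"Q: \", \"A_model:\", and \"A:\" records,\n",
    "    # print the accuracy, and return the questions, predictions, and gold answers.\n",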
    "    with open(filename) as fd:\n",
    "        lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = \"none\"\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        l = l.replace(\",\", \"\")\n",
    "        if l.startswith(\"Q: \"):\n",
    "            if am is not None and a is not None:\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                if test_answer(am, a):\n",
    "                    acc += 1\n",
    "            current_mode = \"q\"\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif l.startswith(\"A_model:\"):\n",
    "            current_mode = \"am\"\n",
    "            am = l\n",
    "        elif l.startswith(\"A:\"):\n",
    "            current_mode = \"a\"\n",
    "            a = l\n",
    "        else:\n",
    "            if current_mode == \"q\":\n",
    "                q += l\n",
    "            elif current_mode == \"am\":\n",
    "                am += l\n",
    "            elif current_mode == \"a\":\n",
    "                a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "\n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    if test_answer(am, a):\n",
    "        acc += 1\n",
    "    print(\"num_q %d correct %d ratio %.4f\" % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold\n",
    "\n",
    "\n",
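    "def get_result(text: str):\n",
    "    # Return the last integer or decimal number in the text, or \"\" if there is none.\n",
    "    pattern = r\"\\d*\\.?\\d+\"\n",
    "    res = re.findall(pattern, text)\n",
    "    return res[-1] if res else \"\"\n",
    "\n",
    "\n",
    "def test_answer(pred_str, ans_str):\n",
    "    # A prediction counts as correct when its last number equals the gold answer's last number.\n",
    "    pred, gold = get_result(pred_str), get_result(ans_str)\n",
    "    return pred == gold"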
   ]
  },
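  {
   "cell_type": "markdown",
   "id": "5e8a0c34-6b2d-4f7a-a1c9-3b4d5e6f7a8b",
   "metadata": {},
   "source": [
    "The scoring rule above compares the last number in the prediction with the last number in the gold answer. The sketch below shows why `parse_pred_ans` strips commas first. `last_number` is a hypothetical stand-in that uses the same pattern as `get_result`, and the sample strings are made up for illustration.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "\n",
    "def last_number(text):\n",
    "    # Same pattern as get_result: the last integer or decimal in the text.\n",
    "    matches = re.findall(r\"\\d*\\.?\\d+\", text)\n",
    "    return matches[-1] if matches else \"\"\n",
    "\n",
    "\n",
    "pred = \"So he made a profit of $70,000\\nThe answer is 70,000\"\n",
    "gold = \"#### 70000\"\n",
    "\n",
    "print(last_number(pred))  # 000, because the comma splits the number\n",
    "cleaned = pred.replace(\",\", \"\")\n",
    "print(last_number(cleaned))  # 70000\n",
    "print(last_number(cleaned) == last_number(gold))  # True\n",
    "```\n",
    "\n",
    "Because the comparison is on strings, numerically equal values written differently (for example `70000` and `70000.0`) would not match under this rule."
   ]
  },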
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "cb209d5a-f822-4734-afc5-dafc07cc1bbc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1319/1319 [47:55<00:00,  2.18s/it] \n"
     ]
    }
   ],
   "source": [
    "# Evaluate the compressed prompt on the full GSM8K test set\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "os.makedirs(\"outputs\", exist_ok=True)\n",
    "i = 0\n",
    "\n",
    "compressed_prompt = llm_lingua.compress_prompt(\n",
    "    prompt_complex.split(\"\\n\\n\"),\n",
    "    instruction=\"\",\n",
    "    question=\"\",\n",
    "    target_token=200,\n",
    "    context_budget=\"*1.5\",\n",
    "    iterative_size=100,\n",
    ")\n",
    "\n",
    "for q, a in tqdm(\n",
    "    zip(gsm8k_test[\"question\"], gsm8k_test[\"answer\"]), total=len(gsm8k_test[\"question\"])\n",
    "):\n",
    "    instruction = (\n",
    "        \"Please reference the following examples to answer the math question,\\n\"\n",
    "    )\n",
    "    prompt = (\n",
    "        instruction\n",
    "        + compressed_prompt[\"compressed_prompt\"]\n",
    "        + \"\\n\\nQuestion: \"\n",
    "        + q\n",
    "        + \"\\n\"\n",
    "    )\n",
    "\n",
    "    request_data = {\n",
    "        \"prompt\": prompt,\n",
    "        \"max_tokens\": 400,\n",
    "        \"temperature\": 0,\n",
    "        \"top_p\": 1,\n",
    "        \"n\": 1,\n",
    "        \"stream\": False,\n",
    "    }\n",
    "    response = openai.Completion.create(\n",
    "        \"gpt-3.5-turbo-0301\",\n",
    "        **request_data,\n",
    "    )\n",
    "    ans_model = response[\"choices\"][0][\"text\"]\n",
    "    ans_, residual = extract_ans(ans_model)\n",
    "    with open(\"outputs/test_gpt_3.5_turbo_LLMLingua_174.txt\", \"a\") as fd:\n",
    "        fd.write(\n",
    "            \"Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n\"\n",
    "            % (q, ans_.replace(\"Q:\", \"\").replace(\"A:\", \"\"), a)\n",
    "        )\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "3a35d298-8596-4b92-8dda-8da4250c873c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 1032 ratio 0.7824\n"
     ]
    }
   ],
   "source": [
    "_ = parse_pred_ans(\"outputs/test_gpt_3.5_turbo_LLMLingua_174.txt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
